// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.

using System;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.AspNetCore.HttpSys.Internal;
using Microsoft.Extensions.Logging;

namespace Microsoft.AspNetCore.Server.HttpSys.Listener
{
    /// <summary>
    /// Helpers for creating, starting and accepting requests on <see cref="HttpSysListener"/> instances in tests.
    /// </summary>
    internal static class Utilities
    {
        internal static readonly int WriteRetryLimit = 1000;
        internal static readonly byte[] WriteBuffer = new byte[1024 * 1024];

        // When test projects are run in parallel, overlapping port ranges can cause a race condition when looking for free
        // ports during dynamic port allocation, so this range should not overlap with those of other test projects.
        private const int BasePort = 8001;
        private const int MaxPort = 11000;
        private static int NextPort = BasePort;
        private static readonly object PortLock = new object();

        internal static readonly TimeSpan DefaultTimeout = TimeSpan.FromSeconds(15);

        /// <summary>
        /// Creates and starts a listener on a dynamically allocated localhost port, then applies the given
        /// authentication settings. Note that the settings are applied after the listener has been started.
        /// </summary>
        /// <param name="authType">The authentication schemes to enable.</param>
        /// <param name="allowAnonymous">Whether anonymous requests are allowed.</param>
        /// <param name="baseAddress">Receives the URL prefix the listener is registered on.</param>
        internal static HttpSysListener CreateHttpAuthServer(AuthenticationSchemes authType, bool allowAnonymous, out string baseAddress)
        {
            var listener = CreateHttpServer(out baseAddress);
            listener.Options.Authentication.Schemes = authType;
            listener.Options.Authentication.AllowAnonymous = allowAnonymous;
            return listener;
        }

        /// <summary>
        /// Creates and starts a listener on a dynamically allocated localhost port with an empty base path.
        /// </summary>
        /// <param name="baseAddress">Receives the URL prefix the listener is registered on.</param>
        internal static HttpSysListener CreateHttpServer(out string baseAddress)
        {
            return CreateDynamicHttpServer(string.Empty, out _, out baseAddress);
        }

        /// <summary>
        /// Creates and starts a listener on a dynamically allocated localhost port under <paramref name="path"/>.
        /// </summary>
        /// <param name="path">The base path of the URL prefix.</param>
        /// <param name="root">Receives "scheme://host:port" without the base path.</param>
        internal static HttpSysListener CreateHttpServerReturnRoot(string path, out string root)
        {
            return CreateDynamicHttpServer(path, out root, out _);
        }

        /// <summary>
        /// Starts a listener on the next localhost port (from the shared counter) that can be registered.
        /// Ports already in use are skipped. The request queue is named after the port so that
        /// <see cref="CreateServerOnExistingQueue(string)"/> can attach to it by name.
        /// </summary>
        /// <param name="basePath">The base path of the URL prefix.</param>
        /// <param name="root">Receives "scheme://host:port" without the base path.</param>
        /// <param name="baseAddress">Receives the full URL prefix, including the base path.</param>
        /// <exception cref="InvalidOperationException">No port below the upper bound could be registered.</exception>
        internal static HttpSysListener CreateDynamicHttpServer(string basePath, out string root, out string baseAddress)
        {
            lock (PortLock)
            {
                while (NextPort < MaxPort)
                {
                    var port = NextPort++;
                    var prefix = UrlPrefix.Create("http", "localhost", port, basePath);
                    root = prefix.Scheme + "://" + prefix.Host + ":" + prefix.Port;
                    baseAddress = prefix.ToString();
                    var options = new HttpSysOptions();
                    options.UrlPrefixes.Add(prefix);
                    options.RequestQueueName = prefix.Port; // Convention for use with CreateServerOnExistingQueue
                    var listener = new HttpSysListener(options, new LoggerFactory());
                    try
                    {
                        listener.Start();
                        return listener;
                    }
                    catch (HttpSysException ex) when (ex.ErrorCode == UnsafeNclNativeMethods.ErrorCodes.ERROR_ALREADY_EXISTS
                        || ex.ErrorCode == UnsafeNclNativeMethods.ErrorCodes.ERROR_SHARING_VIOLATION
                        || ex.ErrorCode == UnsafeNclNativeMethods.ErrorCodes.ERROR_ACCESS_DENIED)
                    {
                        // This port cannot be used; release the listener and try the next one.
                        listener.Dispose();
                    }
                    catch
                    {
                        // Any other failure is fatal, but the listener must not be leaked.
                        listener.Dispose();
                        throw;
                    }
                }

                // Exhausted the range: wrap around so a later call starts scanning from the beginning again.
                NextPort = BasePort;
            }
            throw new InvalidOperationException("Failed to locate a free port.");
        }

        /// <summary>
        /// Creates and starts an HTTPS listener on the fixed address https://localhost:9090/.
        /// </summary>
        internal static HttpSysListener CreateHttpsServer()
        {
            return CreateServer("https", "localhost", 9090, string.Empty);
        }

        /// <summary>
        /// Creates and starts a listener on a single explicit URL prefix.
        /// The listener is disposed if configuration or start-up fails.
        /// </summary>
        internal static HttpSysListener CreateServer(string scheme, string host, int port, string path)
        {
            var listener = new HttpSysListener(new HttpSysOptions(), new LoggerFactory());
            try
            {
                listener.Options.UrlPrefixes.Add(UrlPrefix.Create(scheme, host, port, path));
                listener.Start();
                return listener;
            }
            catch
            {
                listener.Dispose();
                throw;
            }
        }

        /// <summary>
        /// Creates and starts a listener attached to an existing request queue, with no authentication
        /// schemes and anonymous access allowed.
        /// </summary>
        /// <param name="requestQueueName">The name of the queue to attach to.</param>
        internal static HttpSysListener CreateServerOnExistingQueue(string requestQueueName)
        {
            return CreateServerOnExistingQueue(AuthenticationSchemes.None, true, requestQueueName);
        }

        /// <summary>
        /// Creates and starts a listener in <see cref="RequestQueueMode.Attach"/> mode on an existing request queue.
        /// The listener is disposed if start-up fails.
        /// </summary>
        /// <param name="authScheme">The authentication schemes to enable.</param>
        /// <param name="allowAnonymos">Whether anonymous requests are allowed (parameter name kept for compatibility).</param>
        /// <param name="requestQueueName">The name of the queue to attach to.</param>
        internal static HttpSysListener CreateServerOnExistingQueue(AuthenticationSchemes authScheme, bool allowAnonymos, string requestQueueName)
        {
            var options = new HttpSysOptions();
            options.RequestQueueMode = RequestQueueMode.Attach;
            options.RequestQueueName = requestQueueName;
            options.Authentication.Schemes = authScheme;
            options.Authentication.AllowAnonymous = allowAnonymos;
            var listener = new HttpSysListener(options, new LoggerFactory());
            try
            {
                listener.Start();
                return listener;
            }
            catch
            {
                listener.Dispose();
                throw;
            }
        }

        /// <summary>
        /// AcceptAsync extension with timeout. This extension should be used in all tests to prevent
        /// unexpected hangs when a request does not arrive.
        /// </summary>
        /// <param name="server">The listener to accept from.</param>
        /// <param name="timeout">How long to wait for a request that passes validation.</param>
        /// <returns>The first accepted request that passes validation, with its features initialized.</returns>
        /// <exception cref="TimeoutException">
        /// No valid request arrived within <paramref name="timeout"/>; the server is disposed before this is thrown.
        /// </exception>
        internal static async Task<RequestContext> AcceptAsync(this HttpSysListener server, TimeSpan timeout)
        {
            var factory = new TestRequestContextFactory(server);
            using var acceptContext = new AsyncAcceptContext(server, factory);

            // Accepts requests until one passes validation; rejected requests are released and disposed.
            async Task<RequestContext> AcceptValidRequestAsync()
            {
                while (true)
                {
                    var requestContext = await server.AcceptAsync(acceptContext);

                    if (server.ValidateRequest(requestContext))
                    {
                        requestContext.InitializeFeatures();
                        return requestContext;
                    }

                    requestContext.ReleasePins();
                    requestContext.Dispose();
                }
            }

            var acceptTask = AcceptValidRequestAsync();

            // Cancel the delay once the race is decided so its timer does not outlive this call.
            using var timeoutCts = new CancellationTokenSource();
            var completedTask = await Task.WhenAny(acceptTask, Task.Delay(timeout, timeoutCts.Token));
            timeoutCts.Cancel();

            if (completedTask == acceptTask)
            {
                return await acceptTask;
            }

            // NOTE(review): disposing the server presumably also aborts the still-pending accept — confirm.
            server.Dispose();
            throw new TimeoutException("AcceptAsync has timed out.");
        }

        /// <summary>
        /// Fails if the given response task completes before the given accept task.
        /// </summary>
        /// <returns>The accepted request when <paramref name="acceptTask"/> wins the race.</returns>
        /// <exception cref="InvalidOperationException">The response completed first.</exception>
        internal static async Task<RequestContext> Before<T>(this Task<RequestContext> acceptTask, Task<T> responseTask)
        {
            var completedTask = await Task.WhenAny(acceptTask, responseTask);

            if (completedTask == acceptTask)
            {
                return await acceptTask;
            }

            var response = await responseTask;
            // Concatenation (rather than response.ToString()) avoids a NullReferenceException for a null response.
            throw new InvalidOperationException("The response completed prematurely: " + response);
        }

        /// <summary>
        /// Creates <see cref="RequestContext"/> instances bound to a specific listener.
        /// </summary>
        private sealed class TestRequestContextFactory : IRequestContextFactory
        {
            private readonly HttpSysListener _server;

            public TestRequestContextFactory(HttpSysListener server)
            {
                _server = server;
            }

            public RequestContext CreateRequestContext(uint? bufferSize, ulong requestId)
            {
                return new RequestContext(_server, bufferSize, requestId);
            }
        }
    }
}
